"""recur_cnn.py
Recurrent CNN models.
"""

import torch.nn as nn


class CNN(nn.Module):
    """Convolutional classifier with a configurable number of middle layers.

    The network is three convolutional stages followed by a linear head:

    * ``first_layers``: two unpadded 3x3 convolutions
      (``in_channels -> width / 2 -> width``), each followed by ReLU.
    * ``middle_layers``: ``depth - 3`` padded 3x3 convolutions
      (``width -> width``), each followed by ReLU. A ``depth`` of 3 or less
      yields no middle layers.
    * ``last_layers``: ``MaxPool2d(3)``, an unpadded 3x3 convolution
      (``width -> 2 * width``), ReLU, and a second ``MaxPool2d(3)``.

    Args:
        width: Channel count of the middle layers; the final convolution
            outputs ``2 * width`` channels.
        depth: Controls the number of middle layers (``depth - 3``).
        in_channels: Number of channels of the input images.
        num_classes: Size of the output of the linear head.
        dataset: Selects the input size of the linear head. ``"CIFAR10"``
            (compared case-insensitively) uses ``8 * width`` features, which
            is what the convolutional stack produces for 32x32 inputs. Any
            other value uses ``72 * width`` features, which matches 64x64
            inputs.
    """

    def __init__(self, width=64, depth=4, in_channels=3, num_classes=10, dataset="CIFAR10"):
        super().__init__()
        self.dataset = dataset
        self.width = width
        self.depth = depth
        self.in_channels = in_channels

        half_width = int(self.width / 2)
        self.first_layers = nn.Sequential(
            nn.Conv2d(self.in_channels, half_width, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(half_width, self.width, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # padding=1 with a 3x3 kernel keeps the spatial size, so the number
        # of middle layers does not affect the size of the linear head.
        middle_blocks = [
            nn.Sequential(
                nn.Conv2d(self.width, self.width, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
            )
            for _ in range(depth - 3)
        ]
        self.middle_layers = nn.Sequential(*middle_blocks)

        self.last_layers = nn.Sequential(
            nn.MaxPool2d(3),
            nn.Conv2d(self.width, 2 * self.width, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(3),
        )

        # Flattened feature count: 2 * width channels over a 2x2 map for
        # 32x32 inputs, or over a 6x6 map for 64x64 inputs.
        if self.dataset.upper() == "CIFAR10":
            self.linear = nn.Linear(8 * width, num_classes)
        else:
            self.linear = nn.Linear(72 * width, num_classes)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Input batch shaped ``(N, in_channels, H, W)``.

        Returns:
            Tensor shaped ``(N, num_classes)``.
        """
        out = self.first_layers(x)
        out = self.middle_layers(out)
        out = self.last_layers(out)
        # flatten() falls back to a copy when the feature map cannot be
        # viewed as 2-D (e.g. channels_last layout), where view() would raise.
        out = out.flatten(1)
        out = self.linear(out)
        return out

class CNN_MNIST(nn.Module):
    """Single-channel convolutional classifier with one input convolution.

    The network is three convolutional stages followed by a linear head:

    * ``first_layers``: one unpadded 3x3 convolution
      (``in_channels -> width``) followed by ReLU.
    * ``middle_layers``: ``depth - 3`` padded 3x3 convolutions
      (``width -> width``), each followed by ReLU. A ``depth`` of 3 or less
      yields no middle layers.
    * ``last_layers``: ``MaxPool2d(3)``, an unpadded 3x3 convolution
      (``width -> 2 * width``), ReLU, and a second ``MaxPool2d(3)``.

    The linear head expects ``8 * width`` features, which is what the
    convolutional stack produces for 28x28 inputs.

    Args:
        width: Channel count of the first and middle layers; the final
            convolution outputs ``2 * width`` channels.
        depth: Controls the number of middle layers (``depth - 3``).
        in_channels: Number of channels of the input images.
        num_classes: Size of the output of the linear head.
        dataset: Stored on the instance as ``self.dataset``; it does not
            influence the architecture of this class.
    """

    def __init__(self, width=16, depth=4, in_channels=1, num_classes=10, dataset="CIFAR10"):
        super().__init__()
        self.dataset = dataset
        self.width = width
        self.depth = depth
        self.in_channels = in_channels

        self.first_layers = nn.Sequential(
            nn.Conv2d(self.in_channels, int(self.width), kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # padding=1 with a 3x3 kernel keeps the spatial size, so the number
        # of middle layers does not affect the size of the linear head.
        middle_blocks = [
            nn.Sequential(
                nn.Conv2d(self.width, self.width, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
            )
            for _ in range(depth - 3)
        ]
        self.middle_layers = nn.Sequential(*middle_blocks)

        self.last_layers = nn.Sequential(
            nn.MaxPool2d(3),
            nn.Conv2d(self.width, 2 * self.width, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(3),
        )

        # 2 * width channels over a 2x2 feature map.
        self.linear = nn.Linear(8 * width, num_classes)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Input batch shaped ``(N, in_channels, H, W)``.

        Returns:
            Tensor shaped ``(N, num_classes)``.
        """
        out = self.first_layers(x)
        out = self.middle_layers(out)
        out = self.last_layers(out)
        # flatten() falls back to a copy when the feature map cannot be
        # viewed as 2-D (e.g. channels_last layout), where view() would raise.
        out = out.flatten(1)
        out = self.linear(out)
        return out

class CNN_EMNIST(nn.Module):
    """Single-channel convolutional classifier with two input convolutions.

    The network is three convolutional stages followed by a linear head:

    * ``first_layers``: two unpadded 3x3 convolutions
      (``in_channels -> width / 2 -> width``), each followed by ReLU.
    * ``middle_layers``: ``depth - 3`` padded 3x3 convolutions
      (``width -> width``), each followed by ReLU. A ``depth`` of 3 or less
      yields no middle layers.
    * ``last_layers``: ``MaxPool2d(3)``, an unpadded 3x3 convolution
      (``width -> 2 * width``), ReLU, and a second ``MaxPool2d(3)``.

    The linear head expects ``8 * width`` features, which is what the
    convolutional stack produces for 28x28 inputs.

    Args:
        width: Channel count of the middle layers; the final convolution
            outputs ``2 * width`` channels.
        depth: Controls the number of middle layers (``depth - 3``).
        in_channels: Number of channels of the input images.
        num_classes: Size of the output of the linear head.
        dataset: Stored on the instance as ``self.dataset``; it does not
            influence the architecture of this class.
    """

    def __init__(self, width=128, depth=4, in_channels=1, num_classes=10, dataset="CIFAR10"):
        super().__init__()
        self.dataset = dataset
        self.width = width
        self.depth = depth
        self.in_channels = in_channels

        half_width = int(self.width / 2)
        self.first_layers = nn.Sequential(
            nn.Conv2d(self.in_channels, half_width, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(half_width, self.width, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # padding=1 with a 3x3 kernel keeps the spatial size, so the number
        # of middle layers does not affect the size of the linear head.
        middle_blocks = [
            nn.Sequential(
                nn.Conv2d(self.width, self.width, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
            )
            for _ in range(depth - 3)
        ]
        self.middle_layers = nn.Sequential(*middle_blocks)

        self.last_layers = nn.Sequential(
            nn.MaxPool2d(3),
            nn.Conv2d(self.width, 2 * self.width, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(3),
        )

        # 2 * width channels over a 2x2 feature map.
        self.linear = nn.Linear(8 * width, num_classes)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Input batch shaped ``(N, in_channels, H, W)``.

        Returns:
            Tensor shaped ``(N, num_classes)``.
        """
        out = self.first_layers(x)
        out = self.middle_layers(out)
        out = self.last_layers(out)
        # flatten() falls back to a copy when the feature map cannot be
        # viewed as 2-D (e.g. channels_last layout), where view() would raise.
        out = out.flatten(1)
        out = self.linear(out)
        return out

def cnn_4(num_outputs=10):
    """Return a ``CNN`` built with ``depth=4`` and ``num_outputs`` classes."""
    model = CNN(depth=4, num_classes=num_outputs)
    return model


def cnn_5(num_outputs=10):
    """Return a ``CNN`` built with ``depth=5`` and ``num_outputs`` classes."""
    model = CNN(depth=5, num_classes=num_outputs)
    return model


def cnn_6(num_outputs=10):
    """Return a ``CNN`` built with ``depth=6`` and ``num_outputs`` classes."""
    model = CNN(depth=6, num_classes=num_outputs)
    return model


def cnn_7(num_outputs=10):
    """Return a ``CNN`` built with ``depth=7`` and ``num_outputs`` classes."""
    model = CNN(depth=7, num_classes=num_outputs)
    return model


def cnn_8(num_outputs=10):
    """Return a ``CNN`` built with ``depth=8`` and ``num_outputs`` classes."""
    model = CNN(depth=8, num_classes=num_outputs)
    return model

def cnn_4_mnist(num_outputs=10):
    """Return a ``CNN_MNIST`` built with ``depth=4`` and ``num_outputs`` classes."""
    model = CNN_MNIST(depth=4, num_classes=num_outputs)
    return model


def cnn_5_mnist(num_outputs=10):
    """Return a ``CNN_MNIST`` built with ``depth=5`` and ``num_outputs`` classes."""
    model = CNN_MNIST(depth=5, num_classes=num_outputs)
    return model


def cnn_6_mnist(num_outputs=10):
    """Return a ``CNN_MNIST`` built with ``depth=6`` and ``num_outputs`` classes."""
    model = CNN_MNIST(depth=6, num_classes=num_outputs)
    return model


def cnn_7_mnist(num_outputs=10):
    """Return a ``CNN_MNIST`` built with ``depth=7`` and ``num_outputs`` classes."""
    model = CNN_MNIST(depth=7, num_classes=num_outputs)
    return model


def cnn_8_mnist(num_outputs=10):
    """Return a ``CNN_MNIST`` built with ``depth=8`` and ``num_outputs`` classes."""
    model = CNN_MNIST(depth=8, num_classes=num_outputs)
    return model


def cnn_4_emnist(num_outputs=47):
    """Return a ``CNN_EMNIST`` built with ``depth=4`` and ``num_outputs`` classes."""
    model = CNN_EMNIST(depth=4, num_classes=num_outputs)
    return model


def cnn_5_emnist(num_outputs=47):
    """Return a ``CNN_EMNIST`` built with ``depth=5`` and ``num_outputs`` classes."""
    model = CNN_EMNIST(depth=5, num_classes=num_outputs)
    return model


def cnn_6_emnist(num_outputs=47):
    """Return a ``CNN_EMNIST`` built with ``depth=6`` and ``num_outputs`` classes."""
    model = CNN_EMNIST(depth=6, num_classes=num_outputs)
    return model


def cnn_7_emnist(num_outputs=47):
    """Return a ``CNN_EMNIST`` built with ``depth=7`` and ``num_outputs`` classes."""
    model = CNN_EMNIST(depth=7, num_classes=num_outputs)
    return model


def cnn_8_emnist(num_outputs=47):
    """Return a ``CNN_EMNIST`` built with ``depth=8`` and ``num_outputs`` classes."""
    model = CNN_EMNIST(depth=8, num_classes=num_outputs)
    return model


def cnn_4_tinyimagenet(num_outputs=200):
    """Return a ``CNN`` with ``depth=4`` configured for ``dataset="TINYIMAGENET"``."""
    model = CNN(depth=4, dataset="TINYIMAGENET", num_classes=num_outputs)
    return model


def cnn_5_tinyimagenet(num_outputs=200):
    """Return a ``CNN`` with ``depth=5`` configured for ``dataset="TINYIMAGENET"``."""
    model = CNN(depth=5, dataset="TINYIMAGENET", num_classes=num_outputs)
    return model


def cnn_6_tinyimagenet(num_outputs=200):
    """Return a ``CNN`` with ``depth=6`` configured for ``dataset="TINYIMAGENET"``."""
    model = CNN(depth=6, dataset="TINYIMAGENET", num_classes=num_outputs)
    return model


def cnn_7_tinyimagenet(num_outputs=200):
    """Return a ``CNN`` with ``depth=7`` configured for ``dataset="TINYIMAGENET"``."""
    model = CNN(depth=7, dataset="TINYIMAGENET", num_classes=num_outputs)
    return model


def cnn_8_tinyimagenet(num_outputs=200):
    """Return a ``CNN`` with ``depth=8`` configured for ``dataset="TINYIMAGENET"``."""
    model = CNN(depth=8, dataset="TINYIMAGENET", num_classes=num_outputs)
    return model
